package moa.classifiers.JanStaniewicz.helpers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.stream.Collectors;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;

//import com.sun.org.apache.xerces.internal.util.ParserConfigurationSettings;

import com.yahoo.labs.samoa.instances.Instance;

/**
 * Cluster that measures Mahalanobis distance to its trained samples and uses
 * that distance to pick stored minority-class instances for oversampling.
 *
 * <p>
 * Samples are accumulated through {@link #train(double[])}; the inverse
 * covariance matrix is computed lazily on the first {@link #distance(double[])}
 * call after training and cached until the next training call invalidates it.
 */
public class Mahalanobis extends AbstractCluster implements OversamplingMethod, Serializable {

	/** Running sum of the products {@code x[i] * x[j]} over all trained samples. */
	double[][] sumxy;

	/** Cached inverse covariance matrix; {@code null} means it must be recomputed. */
	double[][] invcov;

	/** Threshold compared against the chunk ratio in {@code GenerateNewInstances}. */
	double PostBalanceRatio;

	/**
	 * Stored instances that are ranked by distance (or returned as-is) when
	 * oversampling. Initialised at declaration so that it is never {@code null},
	 * whichever constructor was used.
	 */
	protected ArrayList<Instance> PastPositiveInstances = new ArrayList<Instance>();

	/** Class value that identifies minority instances inside a chunk. */
	double MinorityClass;

	/**
	 * Returns the stored past instances.
	 *
	 * @param pastPositiveInstances
	 *            ignored; kept only so the signature stays compatible with
	 *            existing callers
	 * @return the internal list itself (not a copy)
	 */
	public ArrayList<Instance> getPastInstances(ArrayList<Instance> pastPositiveInstances) {
		return PastPositiveInstances;
	}

	/**
	 * constructs an oversampler for the given minority class; the number of
	 * dimensions is not set here but on each call of
	 * {@code GenerateNewInstances}
	 *
	 * @param minorityClass
	 *            class value treated as the minority class
	 * @param postBalanceRatio
	 *            threshold that decides whether the stored instances are
	 *            returned unranked
	 */
	public Mahalanobis(double minorityClass, double postBalanceRatio) {
		MinorityClass = minorityClass;
		this.PostBalanceRatio = postBalanceRatio;
	}

	/**
	 * constructs mahalanobis distance cluster
	 *
	 * @param dimensions
	 *            amount of dimensions in cluster
	 */
	public Mahalanobis(int dimensions) {
		super(dimensions);
		sumxy = new double[dimensions][dimensions];
	}

	/**
	 * convenience constructor to instantiate trained distance cluster
	 *
	 * @param mx
	 *            expectation values
	 * @param invcov
	 *            inverse covariance matrix
	 */
	public Mahalanobis(double[] mx, double[][] invcov) {
		super(mx);
		this.invcov = invcov;
	}

	/**
	 * calculate mahalanobis distance
	 *
	 * @param features
	 *            amount of features shall correspond to amount dimensions
	 * @return calculated distance
	 */
	public double distance(double[] features) {
		// if we were invalidated, recalculate matrix
		if (invcov == null) {
			invcov = matrix();
		}
		final int dims = getDimensions();
		final double[] mean = center();

		// quadratic form (x - mean)^T * invcov * (x - mean)
		double cumulated = 0;
		for (int i = 0; i < dims; i++) {
			double xmxc = 0;
			for (int j = 0; j < dims; j++) {
				xmxc += invcov[j][i] * (features[j] - mean[j]);
			}
			cumulated += xmxc * (features[i] - mean[i]);
		}

		// The quadratic form has been observed to come out negative, which would
		// turn the square root into NaN. NOTE(review): presumably caused by an
		// ill-conditioned covariance matrix whose numerical inverse is not
		// positive definite - confirm. The absolute value is kept as a guard.
		return Math.sqrt(Math.abs(cumulated));
	}

	/**
	 * gather samples - sum of x*y into matrix
	 *
	 * @param samples
	 *            samples belonging to cluster
	 */
	@Override
	public void train(double[] samples) {
		super.train(samples);
		// invalidate cumulated covariance
		invcov = null;
		final int dims = getDimensions();
		// not every constructor allocates the accumulator, so create it on demand
		if (sumxy == null) {
			sumxy = new double[dims][dims];
		}
		for (int i = 0; i < dims; i++) {
			for (int j = 0; j < dims; j++) {
				sumxy[i][j] += samples[i] * samples[j];
			}
		}
	}

	/**
	 * calculate covariance matrix and invert it
	 *
	 * @return inverse of the covariance matrix built from the accumulated
	 *         sums, the amount of samples and the cluster center
	 */
	double[][] matrix() {
		final int dims = getDimensions();
		final double sampleCount = getAmountSamples();
		final double[] mean = center();

		// cov[i][j] = E[x_i * x_j] - E[x_i] * E[x_j]
		double[][] cov = new double[dims][dims];
		for (int i = 0; i < dims; i++) {
			for (int j = 0; j < dims; j++) {
				cov[i][j] = sumxy[i][j] / sampleCount - mean[i] * mean[j];
			}
		}

		RealMatrix a = new Array2DRowRealMatrix(cov);
		// Double.MIN_VALUE is passed as the singularity threshold of the decomposition
		DecompositionSolver solver = new LUDecomposition(a, Double.MIN_VALUE).getSolver();
		return solver.getInverse().getData();
	}

	public double[][] getInvcov() {
		return invcov;
	}

	public void setInvcov(double[][] invcov) {
		this.invcov = invcov;
	}

	/**
	 * Selects stored instances to add to the current chunk.
	 *
	 * <p>
	 * When {@code PostBalanceRatio} exceeds
	 * {@code previousRatio * (chunkNumber - 1)} the whole internal list of
	 * stored instances is returned unranked (the list itself, not a copy).
	 * Otherwise the cluster is trained on the minority-class instances of the
	 * chunk and the stored instances closest to it are returned, nearest first.
	 *
	 * @param chunkInfo
	 *            current chunk together with its ratio and number
	 * @param numberOfSamples
	 *            maximum amount of instances to return from the ranked branch;
	 *            values of zero or less yield an empty list there
	 * @return the selected instances
	 */
	@Override
	public Collection<? extends Instance> GenerateNewInstances(ChunkInfo chunkInfo, int numberOfSamples) {
		if (PostBalanceRatio > chunkInfo.previousRatio * (chunkInfo.chunkNumber - 1)) {
			return PastPositiveInstances;
		}
		if (numberOfSamples <= 0) {
			return new ArrayList<Instance>();
		}

		// dimensionality is taken from the first instance of the chunk
		setDimensions(chunkInfo.chunk.get(0).numAttributes() - 1);
		for (Instance inst : chunkInfo.chunk) {
			if (inst.classValue() == MinorityClass) {
				train(featuresOf(inst));
			}
		}

		ArrayList<InstMahalanobisDist> ranked = new ArrayList<InstMahalanobisDist>();
		for (Instance past : PastPositiveInstances) {
			ranked.add(new InstMahalanobisDist(past, distance(featuresOf(past))));
		}
		ranked.sort((a, b) -> a.Distance.compareTo(b.Distance));

		// limit() keeps everything when fewer than numberOfSamples are stored
		return ranked.stream().limit(numberOfSamples).map(entry -> entry.Inst).collect(Collectors.toList());
	}

	/**
	 * Copies all attribute values of an instance except the last one.
	 * NOTE(review): presumably the last attribute is the class label - confirm.
	 */
	private static double[] featuresOf(Instance inst) {
		return Arrays.copyOf(inst.toDoubleArray(), inst.numAttributes() - 1);
	}
}
